import numpy as np
import torch
import torch.nn.functional as F

import rl_utils


class ValueNet(torch.nn.Module):
    """State-value critic: a two-layer MLP mapping a state to a scalar V(s).

    Args:
        state_dim: size of the last dimension of the input.
        hidden_dim: width of the single ReLU hidden layer.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        # Names fc1/fc2 are kept so saved state_dict keys remain compatible.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return V(x); the output's last dimension has size 1."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


class PolicyNet(torch.nn.Module):
    """Categorical policy: a two-layer MLP that outputs action probabilities.

    Args:
        state_dim: size of the last dimension of the input.
        hidden_dim: width of the single ReLU hidden layer.
        action_dim: number of discrete actions (size of the output's last dim).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities for state(s) ``x``.

        The softmax is taken over the last dimension (where ``fc2`` puts the
        action logits), so both a batch ``(batch, state_dim)`` and a single
        unbatched state ``(state_dim,)`` are accepted; each output slice
        along the last dimension sums to 1.
        """
        x = F.relu(self.fc1(x))
        # dim=-1 rather than dim=1: identical for 2-D input, but also correct
        # for 1-D (unbatched) and higher-rank input.
        return F.softmax(self.fc2(x), dim=-1)



class PPO:
    """Proximal Policy Optimization using the clipped surrogate objective.

    The actor (``PolicyNet``) and critic (``ValueNet``) are separate networks
    with their own Adam optimizers. Each call to ``update`` takes ``epochs``
    gradient steps on both networks using a single batch of transitions.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 epochs, para_GAE_lmbda, para_PPO_clip, discount_factor, device):
        """Build the actor/critic networks and their optimizers.

        Args:
            state_dim: size of the state vector fed to both networks.
            hidden_dim: hidden-layer width of both networks.
            action_dim: number of discrete actions.
            actor_lr: Adam learning rate for the policy network.
            critic_lr: Adam learning rate for the value network.
            epochs: number of gradient steps taken on each batch given to ``update``.
            para_GAE_lmbda: GAE lambda, forwarded to ``rl_utils.compute_advantage``.
            para_PPO_clip: clipping epsilon; the probability ratio is clamped
                to ``[1 - para_PPO_clip, 1 + para_PPO_clip]``.
            discount_factor: gamma used for TD targets and advantage estimation.
            device: torch device holding the networks and training batches.
        """
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.epochs = epochs  # gradient steps per batch of transitions
        self.discount_factor = discount_factor
        self.para_GAE_lmbda = para_GAE_lmbda
        self.para_PPO_clip = para_PPO_clip  # clipping range parameter in PPO
        self.device = device

    def _to_tensor(self, data, dtype=torch.float):
        """Convert a per-step sequence (list of scalars/arrays, or an ndarray) to a tensor on ``self.device``."""
        # np.asarray first: torch.tensor on a list of ndarrays is very slow.
        return torch.as_tensor(np.asarray(data), dtype=dtype, device=self.device)

    def take_action(self, state):
        """Sample an action index from the current policy for one state.

        Args:
            state: a single, unbatched state (list, tuple or ndarray of numbers).

        Returns:
            The sampled action as a Python ``int``.
        """
        # Add a leading batch dimension of 1, as the original [state] wrapping did.
        state = self._to_tensor(state).unsqueeze(0)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            probs = self.actor(state)
        action = torch.distributions.Categorical(probs).sample()
        return action.item()

    def update(self, transition_dict):
        """Train actor and critic on one batch of transitions.

        Args:
            transition_dict: mapping with keys ``'states'``, ``'actions'``,
                ``'rewards'``, ``'next_states'`` and ``'dones'``, each a
                per-step sequence of equal length. Presumably one trajectory in
                time order, since ``rl_utils.compute_advantage`` accumulates
                over the sequence — confirm against the caller.
        """
        states = self._to_tensor(transition_dict['states'])
        # gather() requires int64 indices, so fix the dtype explicitly.
        actions = self._to_tensor(transition_dict['actions'], dtype=torch.long).reshape(-1, 1)
        rewards = self._to_tensor(transition_dict['rewards']).reshape(-1, 1)
        next_states = self._to_tensor(transition_dict['next_states'])
        dones = self._to_tensor(transition_dict['dones']).reshape(-1, 1)

        # Targets, advantages and old log-probs are constants of the objective,
        # so compute them without building an autograd graph.
        with torch.no_grad():
            # (1 - dones) stops bootstrapping from the next state at terminal steps.
            td_target = rewards + self.discount_factor * self.critic(next_states) * (1 - dones)
            # Temporal-difference error.
            td_delta = td_target - self.critic(states)
            # Log-probabilities of the taken actions under the behaviour policy (theta').
            old_log_probs = torch.log(self.actor(states).gather(1, actions))
        # Advantage estimates for the behaviour policy (theta').
        advantage = rl_utils.compute_advantage(
            self.discount_factor, self.para_GAE_lmbda, td_delta.cpu()
        ).to(self.device)

        low = 1 - self.para_PPO_clip
        high = 1 + self.para_PPO_clip
        for _ in range(self.epochs):
            # Log-probabilities of the taken actions under the current policy (theta).
            log_probs = torch.log(self.actor(states).gather(1, actions))
            # Probability ratio pi_theta / pi_theta'.
            ratio = torch.exp(log_probs - old_log_probs)
            surrogate = ratio * advantage
            clipped_surrogate = torch.clamp(ratio, low, high) * advantage
            # Clipped PPO objective, negated for gradient descent.
            actor_loss = -torch.min(surrogate, clipped_surrogate).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # mse_loss already averages; td_target carries no gradient.
            critic_loss = F.mse_loss(self.critic(states), td_target)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()
            

